{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
    "\n",
    "# Torch-TensorRT Getting Started - LeNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "PyTorch is one of the most approachable tools for developing and experimenting with machine learning models. Its power comes from its deep integration with Python, its flexibility, and its approach to automatic differentiation and execution (eager execution). When moving from research to production, however, the requirements change: we may no longer want that deep Python integration, and we want to optimize for the best possible performance on the deployment platform. PyTorch 1.0 introduced TorchScript as a way to separate a PyTorch model from Python and make it portable and optimizable. TorchScript uses PyTorch's JIT compiler to transform normal PyTorch code, which is interpreted by the Python interpreter, into an intermediate representation (IR) that can be optimized and, at runtime, interpreted by the PyTorch JIT interpreter. For PyTorch this opens up new possibilities, including deployment from other languages such as C++. It also provides a structured, graph-based format that can be used for kernel-level optimization of models for inference.\n",
    "\n",
    "When deploying on NVIDIA GPUs, TensorRT, NVIDIA's deep learning optimization SDK and runtime, can take models from any major framework and tune them for specific target hardware in the NVIDIA family, be it an A100, TITAN V, Jetson Xavier, or NVIDIA's Deep Learning Accelerator. TensorRT achieves this through several sets of optimizations: it fuses layers and tensors in the model graph, then uses a large kernel library to select the implementations that perform best on the target GPU. TensorRT also has strong support for reduced-precision execution, which lets users leverage the Tensor Cores on Volta and newer GPUs and reduces memory and compute footprints on the device.\n",
    "\n",
    "Torch-TensorRT is a compiler that uses TensorRT to optimize TorchScript code, compiling standard TorchScript modules into ones that internally run with TensorRT optimizations. This lets you stay in the PyTorch ecosystem and keep using its features, such as module composability, its flexible tensor implementation, data loaders, and more. Torch-TensorRT is available for both PyTorch and LibTorch."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Learning objectives\n",
    "\n",
    "This notebook demonstrates the steps for compiling a TorchScript module with Torch-TensorRT on a simple LeNet network. \n",
    "\n",
    "## Content\n",
    "1. [Requirements](#1)\n",
    "1. [Creating TorchScript modules](#2)\n",
    "1. [Compiling with Torch-TensorRT](#3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"1\"></a>\n",
    "## 1. Requirements\n",
    "\n",
    "Follow the steps in `notebooks/README` to prepare a Docker container, within which you can run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Thu Feb 10 22:01:27 2022       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 510.39.01    Driver Version: 510.39.01    CUDA Version: 11.6     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  NVIDIA GeForce ...  On   | 00000000:09:00.0 Off |                  N/A |\n",
      "|  0%   42C    P8    20W / 320W |      0MiB / 10240MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "|  No running processes found                                                 |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!pip install ipywidgets --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host=files.pythonhosted.org\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"2\"></a>\n",
    "## 2. Creating TorchScript modules\n",
    "\n",
    "Here we create two submodules, a feature extractor and a classifier, and stitch them together into a single LeNet module. This is overkill for a network of this size, but modules give us granular control over the program, including where we optimize and where we don't. Modules are also the unit that the TorchScript compiler operates on, so you can convert and optimize only the feature extractor and leave the classifier in standard PyTorch, or convert the whole model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class LeNetFeatExtractor(nn.Module):\n",
    "    # Two 3x3 convolutions, each followed by ReLU and 2x2 max pooling.\n",
    "    def __init__(self):\n",
    "        super(LeNetFeatExtractor, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 128, 3)\n",
    "        self.conv2 = nn.Conv2d(128, 16, 3)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
    "        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n",
    "        return x\n",
    "\n",
    "class LeNetClassifier(nn.Module):\n",
    "    # Three fully connected layers mapping the flattened features to 10 classes.\n",
    "    def __init__(self):\n",
    "        super(LeNetClassifier, self).__init__()\n",
    "        # 16 channels of 6x6 features, which is what a 32x32 input produces.\n",
    "        self.fc1 = nn.Linear(16 * 6 * 6, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Flatten everything except the batch dimension.\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "class LeNet(nn.Module):\n",
    "    # Full model: feature extractor followed by classifier.\n",
    "    def __init__(self):\n",
    "        super(LeNet, self).__init__()\n",
    "        self.feat = LeNetFeatExtractor()\n",
    "        self.classifer = LeNetClassifier()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.feat(x)\n",
    "        x = self.classifer(x)\n",
    "        return x\n"
   ]
  },
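  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `16 * 6 * 6` input size of `fc1` follows from the shapes produced by the feature extractor. The sketch below reproduces that arithmetic in plain Python, assuming the default stride of 1 and zero padding of the `nn.Conv2d` layers above and non-overlapping 2x2 pooling. The helper name `lenet_feature_size` is hypothetical and used only for illustration. The sizes 33 and 34 are included because they appear in the input profile used in section 3: only inputs that reduce to 6x6 feature maps (576 features) match `fc1`, whereas a 34x34 input reduces to 7x7 maps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lenet_feature_size(side, kernel=3, pool=2):\n",
    "    # Hypothetical helper: spatial size after two (conv -> max-pool) stages,\n",
    "    # assuming stride-1, unpadded convolutions and non-overlapping pooling.\n",
    "    for _ in range(2):\n",
    "        side = (side - kernel + 1) // pool\n",
    "    return side\n",
    "\n",
    "for side in (32, 33, 34):\n",
    "    out = lenet_feature_size(side)\n",
    "    print(f'{side}x{side} input -> {out}x{out} feature maps -> {16 * out * out} features')"
   ]
  },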
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us define a helper function to benchmark a model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "\n",
    "import torch.backends.cudnn as cudnn\n",
    "cudnn.benchmark = True  # let cuDNN pick the fastest algorithms for the fixed input shape\n",
    "\n",
    "def benchmark(model, input_shape=(1024, 1, 32, 32), dtype='fp32', nwarmup=50, nruns=1000):\n",
    "    # Random input on the GPU, cast to half precision if requested.\n",
    "    input_data = torch.randn(input_shape).to(\"cuda\")\n",
    "    if dtype == 'fp16':\n",
    "        input_data = input_data.half()\n",
    "\n",
    "    # Warm-up runs are excluded from the timings.\n",
    "    print(\"Warm up ...\")\n",
    "    with torch.no_grad():\n",
    "        for _ in range(nwarmup):\n",
    "            features = model(input_data)\n",
    "    torch.cuda.synchronize()\n",
    "\n",
    "    print(\"Start timing ...\")\n",
    "    timings = []\n",
    "    with torch.no_grad():\n",
    "        for i in range(1, nruns + 1):\n",
    "            start_time = time.perf_counter()\n",
    "            features = model(input_data)\n",
    "            # CUDA kernels run asynchronously; wait for them before stopping the clock.\n",
    "            torch.cuda.synchronize()\n",
    "            end_time = time.perf_counter()\n",
    "            timings.append(end_time - start_time)\n",
    "            if i % 100 == 0:\n",
    "                print('Iteration %d/%d, ave batch time %.2f ms'%(i, nruns, np.mean(timings)*1000))\n",
    "\n",
    "    print(\"Input shape:\", input_data.size())\n",
    "    print(\"Output features size:\", features.size())\n",
    "    print('Average batch time: %.2f ms'%(np.mean(timings)*1000))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyTorch model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LeNet(\n",
       "  (feat): LeNetFeatExtractor(\n",
       "    (conv1): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (conv2): Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1))\n",
       "  )\n",
       "  (classifer): LeNetClassifier(\n",
       "    (fc1): Linear(in_features=576, out_features=120, bias=True)\n",
       "    (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
       "    (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = LeNet()\n",
    "model.to(\"cuda\").eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warm up ...\n",
      "Start timing ...\n",
      "Iteration 100/1000, ave batch time 5.56 ms\n",
      "Iteration 200/1000, ave batch time 5.56 ms\n",
      "Iteration 300/1000, ave batch time 5.56 ms\n",
      "Iteration 400/1000, ave batch time 5.56 ms\n",
      "Iteration 500/1000, ave batch time 5.56 ms\n",
      "Iteration 600/1000, ave batch time 5.56 ms\n",
      "Iteration 700/1000, ave batch time 5.56 ms\n",
      "Iteration 800/1000, ave batch time 5.56 ms\n",
      "Iteration 900/1000, ave batch time 5.56 ms\n",
      "Iteration 1000/1000, ave batch time 5.56 ms\n",
      "Input shape: torch.Size([1024, 1, 32, 32])\n",
      "Output features size: torch.Size([1024, 10])\n",
      "Average batch time: 5.56 ms\n"
     ]
    }
   ],
   "source": [
    "benchmark(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When compiling your module to TorchScript, there are two paths: tracing and scripting.\n",
    "\n",
    "### Tracing\n",
    "\n",
    "Tracing follows the path of execution when the module is called and records the operations it performs. This recording is what the TorchScript IR describes. To trace an instance of our LeNet module, we call `torch.jit.trace` with an example input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LeNet(\n",
       "  original_name=LeNet\n",
       "  (feat): LeNetFeatExtractor(\n",
       "    original_name=LeNetFeatExtractor\n",
       "    (conv1): Conv2d(original_name=Conv2d)\n",
       "    (conv2): Conv2d(original_name=Conv2d)\n",
       "  )\n",
       "  (classifer): LeNetClassifier(\n",
       "    original_name=LeNetClassifier\n",
       "    (fc1): Linear(original_name=Linear)\n",
       "    (fc2): Linear(original_name=Linear)\n",
       "    (fc3): Linear(original_name=Linear)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "traced_model = torch.jit.trace(model, torch.empty([1,1,32,32]).to(\"cuda\"))\n",
    "traced_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warm up ...\n",
      "Start timing ...\n",
      "Iteration 100/1000, ave batch time 5.56 ms\n",
      "Iteration 200/1000, ave batch time 5.56 ms\n",
      "Iteration 300/1000, ave batch time 5.56 ms\n",
      "Iteration 400/1000, ave batch time 5.56 ms\n",
      "Iteration 500/1000, ave batch time 5.56 ms\n",
      "Iteration 600/1000, ave batch time 5.56 ms\n",
      "Iteration 700/1000, ave batch time 5.56 ms\n",
      "Iteration 800/1000, ave batch time 5.56 ms\n",
      "Iteration 900/1000, ave batch time 5.56 ms\n",
      "Iteration 1000/1000, ave batch time 5.56 ms\n",
      "Input shape: torch.Size([1024, 1, 32, 32])\n",
      "Output features size: torch.Size([1024, 10])\n",
      "Average batch time: 5.56 ms\n"
     ]
    }
   ],
   "source": [
    "benchmark(traced_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Scripting\n",
    "\n",
    "Scripting inspects your code with a compiler and generates an equivalent TorchScript program. Because tracing simply follows the execution of your module, it cannot capture control flow: it records only the code path that a particular input triggers. The script compiler works from the Python source itself, so it can include these constructs. We can run the script compiler on our LeNet module by calling `torch.jit.script`.\n"
   ]
  },
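  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference concrete, here is a minimal sketch using a toy module with a data-dependent branch. `SignGate` and the variable names are hypothetical and not part of the LeNet example, and the sketch runs on the CPU. Tracing with a positive example input records only the `x + 1` branch (PyTorch emits a `TracerWarning` about the tensor-to-bool conversion), while scripting keeps both branches, so the two results differ for a negative input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "class SignGate(nn.Module):\n",
    "    # Hypothetical toy module whose output depends on a data-dependent branch.\n",
    "    def forward(self, x):\n",
    "        if x.sum() > 0:\n",
    "            return x + 1\n",
    "        return x - 1\n",
    "\n",
    "gate = SignGate()\n",
    "traced_gate = torch.jit.trace(gate, torch.ones(2))  # records only the x + 1 path\n",
    "scripted_gate = torch.jit.script(gate)               # compiles both branches\n",
    "\n",
    "negative = -torch.ones(2)\n",
    "print('eager:   ', gate(negative))\n",
    "print('traced:  ', traced_gate(negative))    # still applies x + 1\n",
    "print('scripted:', scripted_gate(negative))  # matches eager execution"
   ]
  },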
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LeNet().to(\"cuda\").eval()\n",
    "script_model = torch.jit.script(model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RecursiveScriptModule(\n",
       "  original_name=LeNet\n",
       "  (feat): RecursiveScriptModule(\n",
       "    original_name=LeNetFeatExtractor\n",
       "    (conv1): RecursiveScriptModule(original_name=Conv2d)\n",
       "    (conv2): RecursiveScriptModule(original_name=Conv2d)\n",
       "  )\n",
       "  (classifer): RecursiveScriptModule(\n",
       "    original_name=LeNetClassifier\n",
       "    (fc1): RecursiveScriptModule(original_name=Linear)\n",
       "    (fc2): RecursiveScriptModule(original_name=Linear)\n",
       "    (fc3): RecursiveScriptModule(original_name=Linear)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "script_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warm up ...\n",
      "Start timing ...\n",
      "Iteration 100/1000, ave batch time 5.56 ms\n",
      "Iteration 200/1000, ave batch time 5.56 ms\n",
      "Iteration 300/1000, ave batch time 5.56 ms\n",
      "Iteration 400/1000, ave batch time 5.56 ms\n",
      "Iteration 500/1000, ave batch time 5.56 ms\n",
      "Iteration 600/1000, ave batch time 5.56 ms\n",
      "Iteration 700/1000, ave batch time 5.56 ms\n",
      "Iteration 800/1000, ave batch time 5.56 ms\n",
      "Iteration 900/1000, ave batch time 5.56 ms\n",
      "Iteration 1000/1000, ave batch time 5.56 ms\n",
      "Input shape: torch.Size([1024, 1, 32, 32])\n",
      "Output features size: torch.Size([1024, 10])\n",
      "Average batch time: 5.56 ms\n"
     ]
    }
   ],
   "source": [
    "benchmark(script_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"3\"></a>\n",
    "## 3. Compiling with Torch-TensorRT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TorchScript traced model\n",
    "\n",
    "First, we compile the traced TorchScript model with Torch-TensorRT and benchmark it. In the recorded run, the average batch time drops from 5.56 ms to 1.40 ms."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: [Torch-TensorRT] - For input x.1, found user specified input dtype as Float16, however when inspecting the graph, the input type expected was inferred to be Float\n",
      "The compiler is going to use the user setting Float16\n",
      "This conflict may cause an error at runtime due to partial compilation being enabled and therefore\n",
      "compatibility with PyTorch's data type convention is required.\n",
      "If you do indeed see errors at runtime either:\n",
      "- Remove the dtype spec for x.1\n",
      "- Disable partial compilation by setting require_full_compilation to True\n",
      "WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter\n",
      "WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter\n",
      "WARNING: [Torch-TensorRT TorchScript Conversion Context] - Max value of this profile is not valid\n"
     ]
    }
   ],
   "source": [
    "import torch_tensorrt\n",
    "\n",
    "# Compile for a batch size of 1024 in half precision. The min/opt/max shapes\n",
    "# define a dynamic-shape profile in which height and width may range from 32 to 34.\n",
    "trt_ts_module = torch_tensorrt.compile(traced_model, inputs=[torch_tensorrt.Input(\n",
    "            min_shape=[1024, 1, 32, 32],\n",
    "            opt_shape=[1024, 1, 33, 33],\n",
    "            max_shape=[1024, 1, 34, 34],\n",
    "            dtype=torch.half\n",
    "            )],\n",
    "            enabled_precisions={torch.half})\n",
    "\n",
    "# Run the compiled module once on FP16 input, then save it as TorchScript.\n",
    "input_data = torch.randn((1024, 1, 32, 32)).half().to(\"cuda\")\n",
    "result = trt_ts_module(input_data)\n",
    "torch.jit.save(trt_ts_module, \"trt_ts_module.ts\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warm up ...\n",
      "Start timing ...\n",
      "Iteration 100/1000, ave batch time 1.41 ms\n",
      "Iteration 200/1000, ave batch time 1.40 ms\n",
      "Iteration 300/1000, ave batch time 1.40 ms\n",
      "Iteration 400/1000, ave batch time 1.39 ms\n",
      "Iteration 500/1000, ave batch time 1.40 ms\n",
      "Iteration 600/1000, ave batch time 1.40 ms\n",
      "Iteration 700/1000, ave batch time 1.40 ms\n",
      "Iteration 800/1000, ave batch time 1.40 ms\n",
      "Iteration 900/1000, ave batch time 1.40 ms\n",
      "Iteration 1000/1000, ave batch time 1.40 ms\n",
      "Input shape: torch.Size([1024, 1, 32, 32])\n",
      "Output features size: torch.Size([1024, 10])\n",
      "Average batch time: 1.40 ms\n"
     ]
    }
   ],
   "source": [
    "benchmark(trt_ts_module, input_shape=(1024, 1, 32, 32), dtype=\"fp16\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TorchScript script model\n",
    "\n",
    "Next, we compile the scripted TorchScript model with Torch-TensorRT and benchmark it. The recorded run shows the same average batch time of 1.40 ms as the traced model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: [Torch-TensorRT] - For input x.1, found user specified input dtype as Float16, however when inspecting the graph, the input type expected was inferred to be Float\n",
      "The compiler is going to use the user setting Float16\n",
      "This conflict may cause an error at runtime due to partial compilation being enabled and therefore\n",
      "compatibility with PyTorch's data type convention is required.\n",
      "If you do indeed see errors at runtime either:\n",
      "- Remove the dtype spec for x.1\n",
      "- Disable partial compilation by setting require_full_compilation to True\n",
      "WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter\n",
      "WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter\n",
      "WARNING: [Torch-TensorRT TorchScript Conversion Context] - Max value of this profile is not valid\n"
     ]
    }
   ],
   "source": [
    "import torch_tensorrt\n",
    "\n",
    "# Same input profile and precision as for the traced model, applied to the scripted model.\n",
    "trt_script_module = torch_tensorrt.compile(script_model, inputs=[torch_tensorrt.Input(\n",
    "            min_shape=[1024, 1, 32, 32],\n",
    "            opt_shape=[1024, 1, 33, 33],\n",
    "            max_shape=[1024, 1, 34, 34],\n",
    "            dtype=torch.half\n",
    "            )],\n",
    "            enabled_precisions={torch.half})\n",
    "\n",
    "# Run the compiled module once on FP16 input, then save it as TorchScript.\n",
    "input_data = torch.randn((1024, 1, 32, 32)).half().to(\"cuda\")\n",
    "result = trt_script_module(input_data)\n",
    "torch.jit.save(trt_script_module, \"trt_script_module.ts\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warm up ...\n",
      "Start timing ...\n",
      "Iteration 100/1000, ave batch time 1.43 ms\n",
      "Iteration 200/1000, ave batch time 1.41 ms\n",
      "Iteration 300/1000, ave batch time 1.40 ms\n",
      "Iteration 400/1000, ave batch time 1.42 ms\n",
      "Iteration 500/1000, ave batch time 1.42 ms\n",
      "Iteration 600/1000, ave batch time 1.41 ms\n",
      "Iteration 700/1000, ave batch time 1.41 ms\n",
      "Iteration 800/1000, ave batch time 1.40 ms\n",
      "Iteration 900/1000, ave batch time 1.40 ms\n",
      "Iteration 1000/1000, ave batch time 1.40 ms\n",
      "Input shape: torch.Size([1024, 1, 32, 32])\n",
      "Output features size: torch.Size([1024, 10])\n",
      "Average batch time: 1.40 ms\n"
     ]
    }
   ],
   "source": [
    "benchmark(trt_script_module, input_shape=(1024, 1, 32, 32), dtype=\"fp16\")"
   ]
  },
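  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The recorded timings can be turned into a speedup and a throughput figure. The sketch below uses the averages printed in this notebook's saved outputs (5.56 ms for the PyTorch and TorchScript models, 1.40 ms for the Torch-TensorRT modules, all at batch size 1024). The variable names are illustrative, and results on other hardware will differ. Note that the comparison mixes FP32 (baseline) and FP16 (compiled) precision."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1024\n",
    "baseline_ms = 5.56   # recorded average batch time of the FP32 PyTorch/TorchScript models\n",
    "trt_fp16_ms = 1.40   # recorded average batch time of the Torch-TensorRT FP16 modules\n",
    "\n",
    "speedup = baseline_ms / trt_fp16_ms\n",
    "baseline_throughput = batch_size / (baseline_ms / 1000)   # images per second\n",
    "trt_throughput = batch_size / (trt_fp16_ms / 1000)        # images per second\n",
    "\n",
    "print(f'Speedup: {speedup:.2f}x')\n",
    "print(f'Throughput: {baseline_throughput:,.0f} -> {trt_throughput:,.0f} images/s')"
   ]
  },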
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "In this notebook, we walked through the complete process of compiling TorchScript models with Torch-TensorRT and tested the performance impact of the optimization.\n",
    "\n",
    "### What's next\n",
    "Now it's time to try Torch-TensorRT on your own model. If you run into problems, file an issue at https://github.com/NVIDIA/Torch-TensorRT. Your involvement will help the future development of Torch-TensorRT.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
